import itertools

import numpy
import tensorflow as tf
import numpy as np
import copy
import time
from random import shuffle

import os
import numpy as np
import pandas as pd

import tensorflow as tf
from tensorflow import keras
#tf.config.set_per_process_memory_growth(True)
import nalp.utils.constants as c

import pickle

import collections
from collections import defaultdict
import matplotlib.pyplot as plt

from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model, load_model
from keras.layers import Input, Dense, Conv2D, LeakyReLU, Dropout, Flatten, MaxPooling2D, GlobalAveragePooling2D
from keras.layers import BatchNormalization, Embedding, Reshape, Activation
from keras.layers import Concatenate, Conv2DTranspose, multiply, UpSampling2D
from keras.initializers import RandomNormal
from keras.optimizers import Adam
from keras.utils import Progbar
from keras.metrics import *
from keras import backend as K


def getdoKey(obs_Var, intv_key):
    """Build a human-readable query label such as ``P(Y1,Y2|do(X,Z_10))``.

    Args:
        obs_Var: iterable of observed variable names (strings), joined with ",".
        intv_key: the intervention. If it is a mapping (any ``dict`` subclass),
            its keys are joined with "," and its values are stringified and
            concatenated after a "_" separator. Otherwise it is treated as an
            iterable of strings and concatenated with no separator. If it is
            empty, no ``|do(...)`` part is emitted.

    Returns:
        The query string.
    """
    query_str = "P(" + ",".join(obs_Var)
    if len(intv_key) != 0:
        # isinstance (not ``type(...) == dict``) so OrderedDict/defaultdict
        # interventions keep their values in the label.
        if isinstance(intv_key, dict):
            names = ",".join(intv_key.keys())
            values = "".join(str(v) for v in intv_key.values())
            query_str += "|do(" + names + "_" + values + ")"
        else:
            query_str += "|do(" + "".join(intv_key) + ")"
    return query_str + ")"



def generate_permutations(dim_list):
    """Enumerate every joint assignment of discrete variables.

    Args:
        dim_list: number of categories for each variable, in order.

    Returns:
        A NumPy array with one row per assignment, in ``itertools.product``
        order: the last variable varies fastest. Each row holds one index per
        variable.
    """
    axes = (range(size) for size in dim_list)
    return np.array(list(itertools.product(*axes)))

def get_joint_distributions_from_samples(observed_var, dim_list, corrensponding_samples):
    """Estimate an empirical joint distribution from discrete samples.

    Args:
        observed_var: not used by this function; kept for interface compatibility.
        dim_list: number of categories per variable, used to enumerate all
            possible joint assignments.
        corrensponding_samples: sample array. Presumably it has shape
            (n_samples, len(dim_list)) — TODO confirm with callers.

    Returns:
        A dict mapping each assignment tuple to a probability. Observed rows
        get count / n_samples. Assignments never observed keep a floor of 1e-6;
        the result is NOT renormalized, so the total mass can slightly exceed 1.
        Sample rows outside the enumerated grid are added as extra keys.
    """
    all_assignments = generate_permutations(dim_list)
    unique_rows, row_counts = np.unique(corrensponding_samples, axis=0, return_counts=True)

    # Floor value for every possible assignment (avoids zeros in later KL/log use).
    dist = {tuple(row): 1e-6 for row in all_assignments}

    n_samples = corrensponding_samples.shape[0]
    for row, freq in zip(unique_rows, row_counts):
        dist[tuple(row)] = freq / n_samples

    return dist


def calculate_TVD(dist1, dist2, doPrint):
    """Return the total variation distance ``0.5 * sum |p1 - p2|``.

    Args:
        dist1: dict mapping assignment -> probability.
        dist2: dict with the same keys as ``dist1``. A missing key raises
            KeyError.
        doPrint: when truthy, print every assignment whose absolute
            difference exceeds 0.01.

    Returns:
        The TVD. If the two dicts differ in size, returns the sentinel 10000
        instead of raising.
    """
    if len(dist1) != len(dist2):
        # Sentinel "huge distance" rather than an exception (existing contract).
        return 10000

    tvd = 0
    for perm, p1 in dist1.items():
        diff = abs(p1 - dist2[perm])
        tvd += diff
        if doPrint and diff > 0.01:
            print("perm:", perm, "tvd", diff)
    return tvd * 0.5



def calculate_KL(gen, real, doPrint):
    """Return the KL divergence ``KL(real || gen) = sum real * log(real / gen)``.

    Args:
        gen: dict mapping assignment -> probability (the approximating
            distribution).
        real: dict mapping assignment -> probability (the reference
            distribution).
        doPrint: currently unused; kept for interface compatibility.

    Returns:
        The divergence in nats. Terms where ``real`` is 0, or where ``gen`` is
        0, are skipped. The ``gen == 0`` skip means a true infinite divergence
        is silently treated as 0 for that term.

    Raises:
        ValueError: if the two dicts differ in size.
    """
    if len(real) != len(gen):
        raise ValueError('distribution doesnt match size')

    kl = 0
    for perm, p in real.items():
        if p == 0:
            continue
        # Look up gen only after the real==0 check, preserving the original
        # short-circuit (no KeyError for zero-probability real entries).
        q = gen[perm]
        if q == 0:
            continue
        kl += p * np.log(p / q)
    return kl



def map_fill_to_discrete(Exp, ara, dims_list):
    """Decode concatenated per-variable score segments into category indices.

    Args:
        Exp: not used here; kept for interface compatibility.
        ara: 2-D tensor whose columns are consecutive segments, one per
            variable, with widths given by ``dims_list``.
        dims_list: width (number of categories) of each segment, in order.

    Returns:
        A tensor of shape (rows, len(dims_list)) holding the argmax index of
        each segment (``tf.argmax`` default int64 dtype).
    """
    columns = []
    offset = 0
    for width in dims_list:
        segment = ara[:, offset:offset + width]
        best = tf.argmax(segment, axis=1)
        columns.append(tf.reshape(best, [-1, 1]))
        offset += width

    return tf.concat(axis=1, values=columns)

def map_dictfill_to_discrete(Exp, generated_labels_dict, compare_Var):
    """Concatenate per-variable generator outputs and decode them to indices.

    Args:
        Exp: object exposing ``label_dim``, a mapping of variable name to its
            number of categories.
        generated_labels_dict: dict of tensors, concatenated along axis 1 in
            the dict's iteration order. Presumably that order matches
            ``compare_Var`` — TODO confirm with callers.
        compare_Var: variable names whose widths define the segments.

    Returns:
        A NumPy float array of decoded category indices. The ``numpy.ceil``
        call converts the integer argmax result to float.
    """
    widths = [Exp.label_dim[name] for name in compare_Var]
    stacked = tf.concat(axis=1, values=list(generated_labels_dict.values()))
    stacked = tf.reshape(stacked, [-1, sum(widths)])
    discrete = map_fill_to_discrete(Exp, stacked, widths)
    return numpy.ceil(discrete.numpy())



from PIL import Image
def load_dataset(batch_size, img_size, root, data,split):
    """Load the grayscale images listed in ``data['filename']`` into a shuffled, batched dataset.

    Images are read from ``{root}/{split}/{filename}`` and resized to
    ``img_size`` x ``img_size``. Files that do not exist, and images whose
    array has more than 2 dimensions (i.e. not single-channel), are skipped.
    Kept images are replicated to 3 channels and scaled by ``x / 127.5 - 1``;
    this maps 0..255 to -1..1, assuming 8-bit pixels — TODO confirm. They are
    then shuffled with ``random.shuffle``.

    Args:
        batch_size: batch size of the returned ``tf.data.Dataset``.
        img_size: target side length in pixels.
        root: dataset root directory.
        data: table with a ``filename`` column (used via ``data['filename'].tolist()``).
        split: sub-directory name under ``root``.

    Returns:
        ``(dataset, valid_id)``. ``valid_id[i]`` is the row position in
        ``data`` of the i-th image in the shuffled dataset.
    """
    file_list = data['filename'].tolist()

    okay_image = []
    file_failed = 0  # files missing on disk
    failed = 0       # missing files + images rejected for having >2 dims
    valid_id = []
    for row_idx, fname in enumerate(file_list):
        file_path = f'{root}/{split}/{fname}'
        if not os.path.exists(file_path):
            failed += 1
            file_failed += 1
            continue

        # Context manager releases the file handle deterministically.
        with Image.open(file_path) as im:
            img = np.array(im.resize((img_size, img_size)))

        if img.ndim > 2:
            # Only single-channel images are kept.
            failed += 1
            continue

        img = img.reshape((img_size, img_size, 1))
        okay_image.append(np.repeat(img, 3, axis=2))
        valid_id.append(row_idx)

    print('failed', failed, 'filefailed', file_failed)
    okay_image = np.array(okay_image)
    print('Okay image:', okay_image.shape[0])

    okay_image = okay_image * (1. / 127.5) - 1

    # Shuffle images and their source row ids with the same permutation.
    order = list(range(len(valid_id)))
    shuffle(order)
    okay_image = okay_image[order]
    valid_id = [valid_id[t] for t in order]
    print('Shuffled')

    dataset = tf.data.Dataset.from_tensor_slices(okay_image).batch(batch_size)

    return dataset, valid_id



def penalty_calculation(discriminator, X_real, G_fake):
    """Return the per-sample WGAN-GP gradient penalty ``(||dD/dx_hat||_2 - 1)^2``.

    ``x_hat = eps * X_real + (1 - eps) * G_fake`` uses one ``eps ~ U(0, 1)``
    per sample, so each ``x_hat`` lies on the segment between a real sample
    and a generated sample. The gradient norm is taken over ALL non-batch
    axes. For rank-2 inputs this equals the norm over axis 1; for image
    tensors it is the per-sample norm. The old code applied ``axis=1`` to the
    full tensor, which only normed one spatial axis.

    Args:
        discriminator: callable invoked as ``discriminator([x_hat])``.
        X_real: batch of real samples (batch dimension first).
        G_fake: batch of generated samples with the same shape and dtype as
            ``X_real``.

    Returns:
        A tensor of shape (batch,) with the penalty for each sample.
    """
    X_real = tf.convert_to_tensor(X_real)
    real_shape = tf.shape(X_real)
    # Shape (batch, 1, ..., 1): one mixing coefficient per sample, broadcast
    # over the remaining axes; dtype follows the data (e.g. float64 inputs).
    eps_shape = tf.concat([real_shape[:1], tf.ones_like(real_shape[1:])], axis=0)
    epsilon = tf.random.uniform(shape=eps_shape, minval=0., maxval=1., dtype=X_real.dtype)
    interpolation = epsilon * X_real + (1 - epsilon) * G_fake

    with tf.GradientTape() as pena_tape:
        pena_tape.watch(interpolation)
        score = discriminator([interpolation])
    # Compute the gradient after leaving the tape context.
    grads = pena_tape.gradient(score, interpolation)

    # Flatten non-batch axes so the L2 norm is per sample.
    flat_grads = tf.reshape(grads, [tf.shape(grads)[0], -1])
    penalty = (tf.norm(flat_grads, axis=1) - 1.0) ** 2.0

    return penalty

def apply_gumbel_softmax(output, temperature):
    """Draw a Gumbel-softmax sample over the last axis of ``output``.

    Presumably ``output`` holds unnormalized logits — TODO confirm with
    callers.

    Args:
        output: score tensor; Gumbel noise of the same shape is added to it.
        temperature: softmax temperature (the noisy scores are divided by it).

    Returns:
        ``(soft, hard)``. ``soft`` is the relaxed one-hot sample. ``hard`` is
        its int32 argmax index, wrapped in ``stop_gradient``.
    """
    uniform = tf.random.uniform(tf.shape(output), 0, 1)
    # c.EPSILON keeps both logs away from log(0).
    inner = -1 * tf.math.log(uniform + c.EPSILON)
    gumbel_noise = -1 * tf.math.log(inner + c.EPSILON)

    soft = tf.nn.softmax((output + gumbel_noise) / temperature, -1)
    hard = tf.stop_gradient(tf.argmax(soft, axis=-1, output_type=tf.int32))
    return soft, hard




def conditional_prob(data, names, Y,X):
    """Estimate ``P(Y = y | X = x)`` as a frequency ratio over the rows of ``data``.

    Args:
        data: 2-D array indexed as ``data[:, column_list]``; presumably a
            NumPy array of integer-coded variables — TODO confirm.
        names: list of column names; ``names.index(var)`` gives a column.
        Y: dict mapping target variable name to the value to match.
        X: dict mapping conditioning variable name to the value to match.
            An empty dict conditions on all rows.

    Returns:
        ``#rows matching X and Y / (#rows matching X + 1e-6)``. The 1e-6 is
        added only to the denominator, so the result is 0.0 (not a division
        error) when no row matches ``X``.
    """
    y_ind = [names.index(lb) for lb in Y]
    x_ind = [names.index(lb) for lb in X]

    X_values = np.array(list(X.values())).transpose()
    Y_values = np.array(list(Y.values())).transpose()

    # Rows whose conditioning columns equal the requested X values.
    x_mask = np.all(data[:, x_ind] == X_values, axis=1)
    conditioned = data[x_mask]

    # Among those, count rows whose target columns equal the requested Y values.
    y_hits = np.count_nonzero(np.all(conditioned[:, y_ind] == Y_values, axis=1))

    return y_hits / (conditioned.shape[0] + 10 ** -6)




def compare_conditionals_within(Exp, dataset, observed_var, conditioning_var, names):
    """Tabulate ``P(observed_var | conditioning_var = x)`` for every joint value x.

    Args:
        Exp: object exposing ``label_dim`` (variable name -> number of categories).
        dataset: sample array passed through to ``conditional_prob``.
        observed_var: target variable names.
        conditioning_var: conditioning variable names.
        names: column names of ``dataset``.

    Returns:
        A list with one dict per conditioning assignment, in
        ``generate_permutations`` order. Each dict maps a target-value tuple
        to its estimated conditional probability.
    """
    y_assignments = generate_permutations([Exp.label_dim[v] for v in observed_var])
    x_assignments = generate_permutations([Exp.label_dim[v] for v in conditioning_var])

    cond_list = []
    for x_row in x_assignments:
        x_dict = dict(zip(conditioning_var, x_row))
        y_dicts = (dict(zip(observed_var, y_row)) for y_row in y_assignments)
        cond_list.append({
            tuple(y_dict.values()): conditional_prob(dataset, names, y_dict, x_dict)
            for y_dict in y_dicts
        })

    return cond_list


